import torch
import torch.nn as nn
import torch.nn.functional as F


# define the actor network
class ICN(nn.Module):

    def __init__(self, args):
        super(ICN, self).__init__()
        if args.scenario_name in ['GuessingNumber']:
            self.fc1 = nn.Linear(args.obs_shape, args.hidden_size)
            self.fc2 = nn.Linear(args.hidden_size, args.hidden_size)
            self.message_mlp = nn.Linear(args.message_shape * args.n_agents, args.hidden_size)
        elif args.scenario_name == 'RevealingGoal':
            conv_layers = [
                nn.Conv2d(in_channels=args.obs_shape[0], out_channels=16, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
                nn.Flatten()
            ]
            self.fc1 = nn.Sequential(*conv_layers)
            self.fc2 = nn.Linear(16 * args.obs_shape[1] * args.obs_shape[2], args.hidden_size)
            self.message_mlp = nn.Linear(args.message_shape * args.n_agents, 16 * args.obs_shape[1] * args.obs_shape[2])
        else:
            raise NotImplementedError
        self.rnn = nn.GRU(args.hidden_size, args.hidden_size, args.rnn_layers)
        self.action_out = nn.Linear(args.hidden_size, args.action_shape)

        self.message_out = nn.Linear(args.hidden_size, args.message_shape)

        self.args = args

    def forward(self, x, message, available_actions, hidden, batch_size=0, eps=0):
        if self.args.scenario_name in ['GuessingNumber']:
            pass
        elif self.args.scenario_name == 'RevealingGoal':
            bs = x.shape[0]
            x = x.reshape(bs, self.args.obs_shape[0], self.args.obs_shape[1], self.args.obs_shape[2])
        else:
            raise NotImplementedError

        x = F.relu(self.fc1(x))
        x = x + F.relu(self.message_mlp(message))
        x = F.relu(self.fc2(x))

        # print(x.shape, hidden.shape)
        if batch_size == 0:
            x = x.unsqueeze(0)
        else:
            x = x.reshape(-1, batch_size, self.args.hidden_size)
        x, hidden = self.rnn(x, hidden)

        if batch_size == 0:
            x = x.squeeze(0)
        else:
            x = x.reshape(-1, self.args.hidden_size)

        u_logits = self.action_out(x)

        legal_adv = (1 + u_logits - u_logits.min()) * available_actions

        greedy_action_id = legal_adv.argmax(dim=-1)

        if eps > 0:
            random_action = available_actions.multinomial(1).squeeze(1)
            rand = torch.rand(greedy_action_id.size(), device=greedy_action_id.device)
            rand = (rand < eps).long()
            action_id = (greedy_action_id * (1 - rand) + random_action * rand).detach().long()
        else:
            action_id = greedy_action_id.detach().long()


        selected_logits = u_logits.gather(1, action_id.unsqueeze(1))


        #TODO: add mask

        m_logits = self.message_out(x)

        if self.args.scenario_name in ['GuessingNumber']:
            mask = self.generate_gn_mask(action_id, m_logits)
        elif self.args.scenario_name == 'RevealingGoal':
            mask = self.generate_rg_mask(action_id, m_logits)
        else:
            raise NotImplementedError


        m_logits = m_logits.masked_fill(mask.bool(), -1e10)

        m_out = F.gumbel_softmax(m_logits, hard=True)

        if torch.isnan(m_out).any():
            print(0,torch.isnan(m_out).any(), torch.isnan(m_logits).any(), torch.isnan(mask).any())
            for i in range(m_out.shape[0]):
                if torch.isnan(m_out[i]).any() or torch.isnan(m_logits[i]).any() or torch.isnan(mask[i]).any():
                    print(m_out[i], m_logits[i], mask[i], action_id[i])
            exit(0)

        return action_id, selected_logits, m_out, hidden


    def generate_gn_mask(self, action_id, m_logits):

        device = m_logits.device  # 获取输入张量 x 的设备
        mask = torch.zeros_like(m_logits, device=device, dtype=torch.float)  # 创建一个和输入 x 相同形状的零张量，并确保在同一个设备上

        mask[action_id[:] == self.args.action_shape - 2, -1] = 1
        mask[action_id[:] == self.args.action_shape - 2, -2] = 1

        mask[action_id[:] == self.args.action_shape - 1, :-1] = 1

        mask[action_id[:] < self.args.action_shape - 2, -1] = 1
        mask[action_id[:] < self.args.action_shape - 2, :-2] = 1

        return mask

    def generate_rg_mask(self, action_id, m_logits):

        device = m_logits.device  # 获取输入张量 x 的设备
        mask = torch.zeros_like(m_logits, device=device, dtype=torch.float)  # 创建一个和输入 x 相同形状的零张量，并确保在同一个设备上

        mask[action_id[:] == self.args.action_shape - 1, -1] = 1
        mask[action_id[:] < self.args.action_shape - 1, :-1] = 1

        return mask
